{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Execute this cell to install dependencies. The package specs are quoted so the shell does not\n",
    "# glob the brackets; scikit-learn>=1.2 is needed for `.set_output` used in Check 3.\n",
    "%pip install \"sf-hamilton[visualization]\" \"scikit-learn>=1.2\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Scikit-learn transformer models [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/apache/hamilton/blob/main/examples/scikit-learn/transformer/hamilton_notebook.ipynb) [![GitHub badge](https://img.shields.io/badge/github-view_source-2b3137?logo=github)](https://github.com/apache/hamilton/blob/main/examples/scikit-learn/transformer/hamilton_notebook.ipynb)\n",
    "\n",
    "\n",
    "Uncomment and run the cell below if you are in a Google Colab environment. It will:\n",
    "1. Mount Google Drive. You will be asked to authenticate and grant permissions.\n",
    "2. Change directory to Google Drive.\n",
    "3. Make a directory \"hamilton-tutorials\".\n",
    "4. Change directory to it.\n",
    "5. Clone this repository to your Google Drive.\n",
    "6. Move your current directory to this scikit-learn transformer example.\n",
    "7. Install requirements.\n",
    "\n",
    "This means that any modifications will be saved, and you won't lose them if you close your browser."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "## 1. Mount Google Drive\n",
    "# from google.colab import drive\n",
    "# drive.mount('/content/drive')\n",
    "## 2. Change directory to Google Drive.\n",
    "# %cd /content/drive/MyDrive\n",
    "## 3. Make a directory \"hamilton-tutorials\" (-p: no error if it already exists on re-runs)\n",
    "# !mkdir -p hamilton-tutorials\n",
    "## 4. Change directory to it.\n",
    "# %cd hamilton-tutorials\n",
    "## 5. Clone this repository to your Google Drive\n",
    "# !git clone https://github.com/apache/hamilton/\n",
    "## 6. Move your current directory to this scikit-learn transformer example\n",
    "# %cd hamilton/examples/scikit-learn/transformer\n",
    "## 7. Install requirements.\n",
    "# %pip install -r requirements.txt\n",
    "# from IPython.display import clear_output\n",
    "# clear_output()  # optionally clear outputs\n",
    "# To check your current working directory you can type `!pwd` in a cell and run it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load IPython's autoreload extension, which re-imports changed modules before running code,\n",
    "# so edits to module files are picked up without restarting the kernel.\n",
    "%load_ext autoreload\n",
    "# Mode 1 only reloads modules registered with `%aimport`. The DAG modules in this notebook are\n",
    "# built in-notebook with `ad_hoc_utils`; switch to `%autoreload 2` if you move them into .py files.\n",
    "%autoreload 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import annotations\n",
    "\n",
    "import importlib\n",
    "import logging\n",
    "import sys\n",
    "from types import ModuleType\n",
    "from typing import Any, Dict, List\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.base import BaseEstimator, TransformerMixin\n",
    "from sklearn.pipeline import Pipeline\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from sklearn.utils.validation import check_array, check_is_fitted\n",
    "\n",
    "from hamilton import base, driver, log_setup, ad_hoc_utils\n",
    "\n",
    "# Apply Hamilton's logging setup for this notebook, then grab the logger used by the checks below.\n",
    "log_setup.setup_logging()\n",
    "logger = logging.getLogger(__name__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# We'll place the spend calculations into a new module\n",
    "\n",
    "def avg_3wk_spend(spend: pd.Series) -> pd.Series:\n",
    "    \"\"\"Rolling 3 week average spend.\"\"\"\n",
    "    return spend.rolling(3).mean()\n",
    "\n",
    "\n",
    "def spend_per_signup(spend: pd.Series, signups: pd.Series) -> pd.Series:\n",
    "    \"\"\"The cost per signup in relation to spend.\"\"\"\n",
    "    return spend / signups\n",
    "\n",
    "\n",
    "# Bundle the calculation functions into a module that the Hamilton driver below consumes\n",
    "spend_calculations = ad_hoc_utils.create_temporary_module(\n",
    "    avg_3wk_spend,\n",
    "    spend_per_signup,\n",
    "    module_name=\"spend_calculations\",\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# We'll place the spend statistics calculations into a new module\n",
    "\n",
    "def spend_mean(spend: pd.Series) -> float:\n",
    "    \"\"\"Shows function creating a scalar. In this case it computes the mean of the entire column.\"\"\"\n",
    "    return spend.mean()\n",
    "\n",
    "\n",
    "def spend_zero_mean(spend: pd.Series, spend_mean: float) -> pd.Series:\n",
    "    \"\"\"Shows function that takes a scalar. In this case to zero mean spend.\"\"\"\n",
    "    return spend - spend_mean\n",
    "\n",
    "\n",
    "def spend_std_dev(spend: pd.Series) -> float:\n",
    "    \"\"\"Function that computes the standard deviation of the spend column.\"\"\"\n",
    "    return spend.std()\n",
    "\n",
    "\n",
    "def spend_zero_mean_unit_variance(spend_zero_mean: pd.Series, spend_std_dev: float) -> pd.Series:\n",
    "    \"\"\"Function showing one way to make spend have zero mean and unit variance.\"\"\"\n",
    "    return spend_zero_mean / spend_std_dev\n",
    "\n",
    "\n",
    "# Bundle the statistics functions into a module that the Hamilton driver below consumes\n",
    "spend_statistics = ad_hoc_utils.create_temporary_module(\n",
    "    spend_mean,\n",
    "    spend_zero_mean,\n",
    "    spend_std_dev,\n",
    "    spend_zero_mean_unit_variance,\n",
    "    module_name=\"spend_statistics\",\n",
    ")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this example we show a custom scikit-learn `Transformer` class. The class aims to comply with the [scikit-learn transformer specifications](https://scikit-learn.org/stable/developers/develop.html) so that it can be used as a step in broader scikit-learn pipelines. Scikit-learn estimators and pipelines are stateful objects, which is especially helpful with train-test splits: a transformation is fit on the training split and then applied unchanged to the test split. Also, all pipeline, estimator, and transformer objects should be picklable, enabling reproducible pipelines."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "class HamiltonTransformer(BaseEstimator, TransformerMixin):\n",
    "    \"\"\"Scikit-learn compatible Transformer implementing Hamilton behavior.\n",
    "\n",
    "    ``.fit()`` builds a Hamilton ``driver.Driver`` from ``config``, ``modules`` and ``adapter``;\n",
    "    ``.transform()`` executes it with the columns of X as inputs and returns ``final_vars``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        config: dict = None,\n",
    "        modules: List[ModuleType] = None,\n",
    "        adapter: base.HamiltonGraphAdapter = None,\n",
    "        final_vars: List[str] = None,\n",
    "    ):\n",
    "        self.config = {} if config is None else config\n",
    "        self.modules = [] if modules is None else modules\n",
    "        self.adapter = adapter\n",
    "        self.final_vars = [] if final_vars is None else final_vars\n",
    "\n",
    "    def get_params(self, deep: bool = True) -> dict:\n",
    "        \"\"\"Get parameters for this estimator.\n",
    "\n",
    "        :param deep: Accepted for scikit-learn API compatibility: ``clone()``, ``repr()`` and\n",
    "            ``Pipeline.get_params()`` call ``get_params(deep=...)``. It is ignored here.\n",
    "        :return: Current parameters of the estimator\n",
    "        \"\"\"\n",
    "        return {\n",
    "            \"config\": self.config,\n",
    "            \"modules\": self.modules,\n",
    "            \"adapter\": self.adapter,\n",
    "            \"final_vars\": self.final_vars,\n",
    "        }\n",
    "\n",
    "    def set_params(self, **parameters) -> HamiltonTransformer:\n",
    "        \"\"\"Set parameters for this estimator.\n",
    "\n",
    "        :param parameters: Estimator parameters, given as ``name=value`` keyword arguments.\n",
    "        :return: self\n",
    "        \"\"\"\n",
    "        for parameter, value in parameters.items():\n",
    "            setattr(self, parameter, value)\n",
    "        return self\n",
    "\n",
    "    def get_feature_names_out(self, input_features=None):\n",
    "        \"\"\"Get the output column names recorded by the last call to ``.transform()``.\n",
    "\n",
    "        :param input_features: Ignored; accepted for scikit-learn API compatibility.\n",
    "        :return: List of output column names, or None if the last output had no columns.\n",
    "        :raises NotFittedError: if ``.transform()`` has not been called yet.\n",
    "        \"\"\"\n",
    "        check_is_fitted(self, \"feature_names_out_\")\n",
    "        if self.feature_names_out_:\n",
    "            return self.feature_names_out_\n",
    "        return None\n",
    "\n",
    "    # Backward-compatible alias for the previous, misspelled method name.\n",
    "    get_features_names_out = get_feature_names_out\n",
    "\n",
    "    def _get_tags(self) -> dict:\n",
    "        \"\"\"Get scikit-learn compatible estimator tags for introspection\n",
    "\n",
    "        ref: https://scikit-learn.org/stable/developers/develop.html#estimator-tags\n",
    "        \"\"\"\n",
    "        return {\"requires_fit\": True, \"requires_y\": False}\n",
    "\n",
    "    def fit(self, X, y=None, overrides: Dict[str, Any] = None) -> HamiltonTransformer:\n",
    "        \"\"\"Instantiate the Hamilton driver.Driver object.\n",
    "\n",
    "        :param X: Input 2D array; validated and used only to record ``n_features_in_``.\n",
    "        :param y: Ignored; present for scikit-learn API compatibility.\n",
    "        :param overrides: dictionary of override values passed to driver.execute() during .transform()\n",
    "        :return: self\n",
    "        \"\"\"\n",
    "        check_array(X, accept_sparse=True)\n",
    "        self.overrides_ = {} if overrides is None else overrides\n",
    "\n",
    "        self.driver_ = driver.Driver(self.config, *self.modules, adapter=self.adapter)\n",
    "        self.n_features_in_: int = X.shape[1]\n",
    "\n",
    "        return self\n",
    "\n",
    "    def transform(self, X, y=None, **kwargs) -> pd.DataFrame:\n",
    "        \"\"\"Execute the fitted Hamilton Driver on X and return the requested ``final_vars``.\n",
    "\n",
    "        Requires a prior call to ``.fit()`` to instantiate the Hamilton Driver.\n",
    "\n",
    "        :param X: Input 2D array. A DataFrame is checked against the width seen in ``.fit()``\n",
    "            and converted to a dict of column name -> Series; any other type is passed to\n",
    "            the driver as ``inputs`` unchanged.\n",
    "        :param y: Ignored; present for scikit-learn API compatibility.\n",
    "        :param kwargs: Ignored.\n",
    "        :return: the DataFrame returned by the Hamilton Driver's ``execute()``\n",
    "        :raises ValueError: if a DataFrame X has a different number of columns than in ``.fit()``.\n",
    "        \"\"\"\n",
    "        check_is_fitted(self, \"n_features_in_\")\n",
    "\n",
    "        if isinstance(X, pd.DataFrame):\n",
    "            check_array(X, accept_sparse=True)\n",
    "            if X.shape[1] != self.n_features_in_:\n",
    "                raise ValueError(\"Shape of input is different from what was seen in `fit`\")\n",
    "\n",
    "            X = X.to_dict(orient=\"series\")\n",
    "\n",
    "        X_t = self.driver_.execute(final_vars=self.final_vars, overrides=self.overrides_, inputs=X)\n",
    "        self.n_features_out_ = len(self.final_vars)\n",
    "        self.feature_names_out_ = X_t.columns.to_list()\n",
    "        return X_t\n",
    "\n",
    "    def fit_transform(self, X, y=None, **fit_params) -> pd.DataFrame:\n",
    "        \"\"\"Fit the transformer, which builds the Hamilton Driver, then transform X in one step.\n",
    "\n",
    "        :param X: Input 2D array\n",
    "        :param y: Ignored; it is not forwarded to ``.fit()``.\n",
    "        :param fit_params: Forwarded to ``.fit()`` (e.g. ``overrides``).\n",
    "        :return: the DataFrame produced by ``.transform(X)``\n",
    "        \"\"\"\n",
    "        return self.fit(X, **fit_params).transform(X)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:hamilton.telemetry:Note: Hamilton collects completely anonymous data about usage. This will help us improve Hamilton over time. See https://github.com/apache/hamilton#usage-analytics--data-privacy for details.\n"
     ]
    }
   ],
   "source": [
    "# Input data shared by all checks below\n",
    "initial_df = pd.DataFrame(\n",
    "    {\n",
    "        \"signups\": [1, 10, 50, 100, 200, 400],\n",
    "        \"spend\": [10, 10, 20, 40, 40, 50],\n",
    "    }\n",
    ")\n",
    "\n",
    "# Nodes to request from the DAG, including the raw `spend` and `signups` inputs\n",
    "output_columns = [\"spend\", \"signups\", \"avg_3wk_spend\", \"spend_per_signup\", \"spend_zero_mean_unit_variance\"]\n",
    "\n",
    "# Vanilla Hamilton driver over both ad-hoc modules, with an empty config\n",
    "dr = driver.Driver({}, spend_calculations, spend_statistics)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
       " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
       "<!-- Generated by graphviz version 2.43.0 (0)\n",
       " -->\n",
       "<!-- Title: %3 Pages: 1 -->\n",
       "<svg width=\"610pt\" height=\"260pt\"\n",
       " viewBox=\"0.00 0.00 610.43 260.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
       "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 256)\">\n",
       "<title>%3</title>\n",
       "<polygon fill=\"white\" stroke=\"transparent\" points=\"-4,4 -4,-256 606.43,-256 606.43,4 -4,4\"/>\n",
       "<!-- spend_per_signup -->\n",
       "<g id=\"node1\" class=\"node\">\n",
       "<title>spend_per_signup</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"94.24\" cy=\"-162\" rx=\"94.48\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"94.24\" y=\"-158.3\" font-family=\"Times,serif\" font-size=\"14.00\">spend_per_signup</text>\n",
       "</g>\n",
       "<!-- spend_mean -->\n",
       "<g id=\"node2\" class=\"node\">\n",
       "<title>spend_mean</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"274.24\" cy=\"-162\" rx=\"68.49\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"274.24\" y=\"-158.3\" font-family=\"Times,serif\" font-size=\"14.00\">spend_mean</text>\n",
       "</g>\n",
       "<!-- spend_zero_mean -->\n",
       "<g id=\"node5\" class=\"node\">\n",
       "<title>spend_zero_mean</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"274.24\" cy=\"-90\" rx=\"92.88\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"274.24\" y=\"-86.3\" font-family=\"Times,serif\" font-size=\"14.00\">spend_zero_mean</text>\n",
       "</g>\n",
       "<!-- spend_mean&#45;&gt;spend_zero_mean -->\n",
       "<g id=\"edge6\" class=\"edge\">\n",
       "<title>spend_mean&#45;&gt;spend_zero_mean</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M274.24,-143.7C274.24,-135.98 274.24,-126.71 274.24,-118.11\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"277.74,-118.1 274.24,-108.1 270.74,-118.1 277.74,-118.1\"/>\n",
       "</g>\n",
       "<!-- spend -->\n",
       "<g id=\"node3\" class=\"node\">\n",
       "<title>spend</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" cx=\"370.24\" cy=\"-234\" rx=\"69.59\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"370.24\" y=\"-230.3\" font-family=\"Times,serif\" font-size=\"14.00\">Input: spend</text>\n",
       "</g>\n",
       "<!-- spend&#45;&gt;spend_per_signup -->\n",
       "<g id=\"edge1\" class=\"edge\">\n",
       "<title>spend&#45;&gt;spend_per_signup</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M322.48,-220.89C277.06,-209.37 208.42,-191.96 158.28,-179.24\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"159.06,-175.83 148.51,-176.76 157.34,-182.61 159.06,-175.83\"/>\n",
       "</g>\n",
       "<!-- spend&#45;&gt;spend_mean -->\n",
       "<g id=\"edge3\" class=\"edge\">\n",
       "<title>spend&#45;&gt;spend_mean</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M347.98,-216.76C335.13,-207.4 318.8,-195.49 304.77,-185.26\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"306.53,-182.21 296.39,-179.15 302.4,-187.87 306.53,-182.21\"/>\n",
       "</g>\n",
       "<!-- spend_std_dev -->\n",
       "<g id=\"node4\" class=\"node\">\n",
       "<title>spend_std_dev</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"464.24\" cy=\"-90\" rx=\"78.79\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"464.24\" y=\"-86.3\" font-family=\"Times,serif\" font-size=\"14.00\">spend_std_dev</text>\n",
       "</g>\n",
       "<!-- spend&#45;&gt;spend_std_dev -->\n",
       "<g id=\"edge4\" class=\"edge\">\n",
       "<title>spend&#45;&gt;spend_std_dev</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M380.76,-216.14C392.05,-198.14 410.61,-168.86 427.24,-144 433.32,-134.91 440.14,-125.05 446.25,-116.33\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"449.18,-118.26 452.07,-108.07 443.45,-114.23 449.18,-118.26\"/>\n",
       "</g>\n",
       "<!-- spend&#45;&gt;spend_zero_mean -->\n",
       "<g id=\"edge5\" class=\"edge\">\n",
       "<title>spend&#45;&gt;spend_zero_mean</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M370.25,-215.82C369.48,-196.75 365.81,-165.81 351.24,-144 342.67,-131.17 329.85,-120.56 317.11,-112.3\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"318.72,-109.18 308.36,-106.95 315.07,-115.15 318.72,-109.18\"/>\n",
       "</g>\n",
       "<!-- avg_3wk_spend -->\n",
       "<g id=\"node6\" class=\"node\">\n",
       "<title>avg_3wk_spend</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"519.24\" cy=\"-162\" rx=\"83.39\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"519.24\" y=\"-158.3\" font-family=\"Times,serif\" font-size=\"14.00\">avg_3wk_spend</text>\n",
       "</g>\n",
       "<!-- spend&#45;&gt;avg_3wk_spend -->\n",
       "<g id=\"edge7\" class=\"edge\">\n",
       "<title>spend&#45;&gt;avg_3wk_spend</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M402.21,-217.98C424.02,-207.73 453.16,-194.05 476.9,-182.89\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"478.48,-186.02 486.04,-178.6 475.5,-179.68 478.48,-186.02\"/>\n",
       "</g>\n",
       "<!-- spend_zero_mean_unit_variance -->\n",
       "<g id=\"node8\" class=\"node\">\n",
       "<title>spend_zero_mean_unit_variance</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"369.24\" cy=\"-18\" rx=\"159.47\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"369.24\" y=\"-14.3\" font-family=\"Times,serif\" font-size=\"14.00\">spend_zero_mean_unit_variance</text>\n",
       "</g>\n",
       "<!-- spend_std_dev&#45;&gt;spend_zero_mean_unit_variance -->\n",
       "<g id=\"edge9\" class=\"edge\">\n",
       "<title>spend_std_dev&#45;&gt;spend_zero_mean_unit_variance</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M441.73,-72.41C429.37,-63.3 413.85,-51.87 400.35,-41.92\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"402.38,-39.07 392.25,-35.96 398.23,-44.71 402.38,-39.07\"/>\n",
       "</g>\n",
       "<!-- spend_zero_mean&#45;&gt;spend_zero_mean_unit_variance -->\n",
       "<g id=\"edge8\" class=\"edge\">\n",
       "<title>spend_zero_mean&#45;&gt;spend_zero_mean_unit_variance</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M296.76,-72.41C309.11,-63.3 324.63,-51.87 338.13,-41.92\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"340.26,-44.71 346.23,-35.96 336.1,-39.07 340.26,-44.71\"/>\n",
       "</g>\n",
       "<!-- signups -->\n",
       "<g id=\"node7\" class=\"node\">\n",
       "<title>signups</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" cx=\"94.24\" cy=\"-234\" rx=\"77.19\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"94.24\" y=\"-230.3\" font-family=\"Times,serif\" font-size=\"14.00\">Input: signups</text>\n",
       "</g>\n",
       "<!-- signups&#45;&gt;spend_per_signup -->\n",
       "<g id=\"edge2\" class=\"edge\">\n",
       "<title>signups&#45;&gt;spend_per_signup</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M94.24,-215.7C94.24,-207.98 94.24,-198.71 94.24,-190.11\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"97.74,-190.1 94.24,-180.1 90.74,-190.1 97.74,-190.1\"/>\n",
       "</g>\n",
       "</g>\n",
       "</svg>\n"
      ],
      "text/plain": [
       "<graphviz.graphs.Digraph at 0x7fc5dc81e550>"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Visualize the DAG: every function across both modules and the dependencies between them.\n",
    "# Requires `pip install \"sf-hamilton[visualization]\"`, which the first cell of this notebook installs.\n",
    "dr.display_all_functions(output_file_path=None)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Check 1: output of `vanilla driver` == `custom transformer`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "hamilton_df = dr.execute(final_vars=output_columns, inputs=initial_df.to_dict(orient=\"series\"))\n",
    "\n",
    "custom_transformer = HamiltonTransformer(\n",
    "    config={}, modules=[spend_calculations, spend_statistics], final_vars=output_columns\n",
    ")\n",
    "sklearn_df = custom_transformer.fit_transform(initial_df)\n",
    "\n",
    "# `pd.testing.assert_frame_equal` signals a mismatch by raising AssertionError, so that is the\n",
    "# exception to intercept in order to log context before re-raising the original error.\n",
    "try:\n",
    "    pd.testing.assert_frame_equal(sklearn_df, hamilton_df)\n",
    "except AssertionError:\n",
    "    logger.warning(\"Check 1 failed; `sklearn_df` and `hamilton_df` are unequal\")\n",
    "    raise\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Check 2: output of `vanilla driver >> transformation` == `scikit-learn pipeline`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "scaler = StandardScaler()\n",
    "\n",
    "hamilton_df = dr.execute(final_vars=output_columns, inputs=initial_df.to_dict(orient=\"series\"))\n",
    "hamilton_then_sklearn = scaler.fit_transform(hamilton_df)\n",
    "\n",
    "pipeline1 = Pipeline(steps=[(\"hamilton\", custom_transformer), (\"scaler\", scaler)])\n",
    "pipe_custom_then_sklearn = pipeline1.fit_transform(initial_df)\n",
    "\n",
    "# Failed `assert`s and `np.testing.assert_equal` raise AssertionError, so catch that to log\n",
    "# which check failed before re-raising the original error.\n",
    "try:\n",
    "    assert isinstance(hamilton_then_sklearn, np.ndarray)\n",
    "    assert isinstance(pipe_custom_then_sklearn, np.ndarray)\n",
    "\n",
    "    np.testing.assert_equal(pipe_custom_then_sklearn, hamilton_then_sklearn)\n",
    "except AssertionError:\n",
    "    logger.warning(\n",
    "        \"Check 2 failed; `pipe_custom_then_sklearn` and `hamilton_then_sklearn` are unequal\"\n",
    "    )\n",
    "    raise\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Check 3: output of `transformation >> vanilla driver` == `scikit-learn pipeline`\n",
    "The custom transformer requires a DataFrame, so we use scikit-learn's `.set_output` API (available since v1.2) to make `StandardScaler` return one.\n",
    "ref: https://scikit-learn-enhancement-proposals.readthedocs.io/en/latest/slep018/proposal.html"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "scaler = StandardScaler().set_output(transform=\"pandas\")\n",
    "\n",
    "scaled_df = scaler.fit_transform(initial_df)\n",
    "sklearn_then_hamilton = dr.execute(\n",
    "    final_vars=output_columns, inputs=scaled_df.to_dict(orient=\"series\")\n",
    ")\n",
    "\n",
    "pipeline2 = Pipeline(steps=[(\"scaler\", scaler), (\"hamilton\", custom_transformer)])\n",
    "pipe_sklearn_then_custom = pipeline2.fit_transform(initial_df)\n",
    "\n",
    "# Failed `assert`s and `pd.testing.assert_frame_equal` raise AssertionError, so catch that to\n",
    "# log which check failed before re-raising the original error.\n",
    "try:\n",
    "    assert isinstance(sklearn_then_hamilton, pd.DataFrame)\n",
    "    assert isinstance(pipe_sklearn_then_custom, pd.DataFrame)\n",
    "\n",
    "    pd.testing.assert_frame_equal(pipe_sklearn_then_custom, sklearn_then_hamilton)\n",
    "except AssertionError:\n",
    "    logger.warning(\n",
    "        \"Check 3 failed; `pipe_sklearn_then_custom` and `sklearn_then_hamilton` are unequal\"\n",
    "    )\n",
    "    raise\n",
    "\n",
    "logger.info(\"All checks passed. `HamiltonTransformer` behaves properly\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before using Hamilton with scikit-learn further, please be aware of the integration's possible limitations, listed [here](https://github.com/apache/hamilton/tree/main/examples/scikit-learn#limitations-and-todos)."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
